{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# CIFAR10: TrainDataLoaderIter and ValDataLoaderIter\n",
    "\n",
    "This example demonstrates the usage of `TrainDataLoaderIter` and `ValDataLoaderIter` with a ResNet56 on the CIFAR10 dataset.\n",
    "\n",
    "\n",
    "`LRFinder.range_test()` assumes that the `DataLoader` objects passed in yield batches of the form `input, label, *`, where `input` is the tensor passed to the `model` and `label` is the tensor passed to the `criterion`. If this is not the case for your application, you'll probably need to use `TrainDataLoaderIter` and `ValDataLoaderIter`. These two classes wrap the `DataLoader` for the training set and validation set, respectively, and let you customize how `input` and `label` are extracted from each batch.\n",
    "\n",
    "\n",
    "Common use-cases for `TrainDataLoaderIter` and `ValDataLoaderIter`:\n",
    "1. Your `DataLoader`/`Dataset.__getitem__()` returns a `dict` or another container that differs from the form above (`input, label, *`).\n",
    "2. Your `Dataset.__getitem__()` doesn't perform all the data processing/handling/conversion needed for `input` and `label` to be ready to be passed to the `model` and `criterion`, respectively."
   ]
  },
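  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of use-case 2, the sketch below shows the kind of conversion such a hook might perform. Everything in it is hypothetical: the batch layout (`image` and `label_name` keys), the `CLASS_TO_INDEX` table and the function itself are made up for this example, and plain Python lists stand in for tensors.\n",
    "\n",
    "```python\n",
    "# Hypothetical lookup table; a real dataset would define its own classes.\n",
    "CLASS_TO_INDEX = {'airplane': 0, 'automobile': 1, 'bird': 2}\n",
    "\n",
    "\n",
    "def inputs_labels_from_batch(batch):\n",
    "    # Scale raw 0-255 pixel values to [0, 1] so they are ready for the model.\n",
    "    inputs = [[pixel / 255 for pixel in image] for image in batch['image']]\n",
    "    # Map class names to the integer indices a criterion would expect.\n",
    "    labels = [CLASS_TO_INDEX[name] for name in batch['label_name']]\n",
    "    return inputs, labels\n",
    "\n",
    "\n",
    "# Hypothetical batch: two tiny images and their class names.\n",
    "batch = {'image': [[0, 255], [51, 102]], 'label_name': ['bird', 'airplane']}\n",
    "inputs, labels = inputs_labels_from_batch(batch)\n",
    "```\n",
    "\n",
    "With the library, the same logic would live in a method override on a `TrainDataLoaderIter` or `ValDataLoaderIter` subclass rather than in a free function."
   ]
  },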
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torchvision.transforms as transforms\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision.datasets import CIFAR10\n",
    "import cifar10_resnet as rc10\n",
    "from PIL import Image\n",
    "\n",
    "try:\n",
    "    from torch_lr_finder import LRFinder, TrainDataLoaderIter, ValDataLoaderIter\n",
    "except ImportError:\n",
    "    # Run from source\n",
    "    import sys\n",
    "    sys.path.insert(0, '..')\n",
    "    from torch_lr_finder import LRFinder, TrainDataLoaderIter, ValDataLoaderIter"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loading CIFAR10"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To demonstrate how `TrainDataLoaderIter` and `ValDataLoaderIter` can be used, we'll create a new CIFAR10 dataset that instead of returning `img, target` like [the one from torchvision](https://github.com/pytorch/vision/blob/master/torchvision/datasets/cifar.py#L118), returns a dictionary. We are essentially provoking use-case 1 from above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# All we need to do here is inherit from CIFAR10 and change its __getitem__()\n",
    "# method to return a dictionary\n",
    "class CIFAR10WithDict(CIFAR10):\n",
    "    def __getitem__(self, index):\n",
    "        # IMPORTANT: this code works with torchvision 0.6.0; it's not guaranteed to\n",
    "        # work with other versions\n",
    "        img, target = self.data[index], self.targets[index]\n",
    "\n",
    "        # doing this so that it is consistent with all other datasets\n",
    "        # to return a PIL Image\n",
    "        img = Image.fromarray(img)\n",
    "\n",
    "        if self.transform is not None:\n",
    "            img = self.transform(img)\n",
    "\n",
    "        if self.target_transform is not None:\n",
    "            target = self.target_transform(target)\n",
    "            \n",
    "        ret_dict = {}\n",
    "        ret_dict[\"img\"] = img\n",
    "        ret_dict[\"target\"] = target\n",
    "\n",
    "        return ret_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "cifar_pwd = \"../data\"\n",
    "batch_size = 256"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])\n",
    "\n",
    "trainset = CIFAR10WithDict(root=cifar_pwd, train=True, download=True, transform=transform)\n",
    "trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=0)\n",
    "\n",
    "testset = CIFAR10WithDict(root=cifar_pwd, train=False, download=True, transform=transform)\n",
    "testloader = DataLoader(testset, batch_size=batch_size * 2, shuffle=False, num_workers=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = rc10.resnet56()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training loss (fastai)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = optim.Adam(model.parameters(), lr=1e-7, weight_decay=1e-2)\n",
    "lr_finder = LRFinder(model, optimizer, criterion, device=\"cuda\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's try to do a `range_test` and see what happens:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c907007fbc1b42549cf7cd4dc720e3b3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "ename": "ValueError",
     "evalue": "Your batch type not supported: <class 'dict'>. Please inherit from `TrainDataLoaderIter` (or `ValDataLoaderIter`) and redefine `_batch_make_inputs_labels` method.",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-7-e7aa62ecded3>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mlr_finder\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrange_test\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrainloader\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mend_lr\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m100\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_iter\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m100\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstep_mode\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"exp\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m~/datascience/pytorch/pytorch-lr-finder/torch_lr_finder/lr_finder.py\u001b[0m in \u001b[0;36mrange_test\u001b[0;34m(self, train_loader, val_loader, start_lr, end_lr, num_iter, step_mode, smooth_f, diverge_th, accumulation_steps, non_blocking_transfer)\u001b[0m\n\u001b[1;32m    288\u001b[0m                 \u001b[0mtrain_iter\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    289\u001b[0m                 \u001b[0maccumulation_steps\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 290\u001b[0;31m                 \u001b[0mnon_blocking_transfer\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mnon_blocking_transfer\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    291\u001b[0m             )\n\u001b[1;32m    292\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mval_loader\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/datascience/pytorch/pytorch-lr-finder/torch_lr_finder/lr_finder.py\u001b[0m in \u001b[0;36m_train_batch\u001b[0;34m(self, train_iter, accumulation_steps, non_blocking_transfer)\u001b[0m\n\u001b[1;32m    339\u001b[0m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    340\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0maccumulation_steps\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 341\u001b[0;31m             \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_iter\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    342\u001b[0m             inputs, labels = self._move_to_device(\n\u001b[1;32m    343\u001b[0m                 \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_blocking\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mnon_blocking_transfer\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/datascience/pytorch/pytorch-lr-finder/torch_lr_finder/lr_finder.py\u001b[0m in \u001b[0;36m__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m     56\u001b[0m         \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     57\u001b[0m             \u001b[0mbatch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_iterator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 58\u001b[0;31m             \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minputs_labels_from_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     59\u001b[0m         \u001b[0;32mexcept\u001b[0m \u001b[0mStopIteration\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     60\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mauto_reset\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/datascience/pytorch/pytorch-lr-finder/torch_lr_finder/lr_finder.py\u001b[0m in \u001b[0;36minputs_labels_from_batch\u001b[0;34m(self, batch_data)\u001b[0m\n\u001b[1;32m     33\u001b[0m                 \u001b[0;34m\"Your batch type not supported: {}. Please inherit from \"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     34\u001b[0m                 \u001b[0;34m\"`TrainDataLoaderIter` (or `ValDataLoaderIter`) and redefine \"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 35\u001b[0;31m                 \u001b[0;34m\"`_batch_make_inputs_labels` method.\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch_data\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     36\u001b[0m             )\n\u001b[1;32m     37\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mValueError\u001b[0m: Your batch type not supported: <class 'dict'>. Please inherit from `TrainDataLoaderIter` (or `ValDataLoaderIter`) and redefine `_batch_make_inputs_labels` method."
     ]
    }
   ],
   "source": [
    "lr_finder.range_test(trainloader, end_lr=100, num_iter=100, step_mode=\"exp\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
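    "The problem is that `LRFinder` doesn't know how to extract the inputs and labels from a batch that is a dictionary. This is where `TrainDataLoaderIter` comes in. Note that the error message refers to `_batch_make_inputs_labels`, while the traceback shows that the method raising the error, and the one to override, is `inputs_labels_from_batch()`:"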
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CustomTrainIter(TrainDataLoaderIter):\n",
    "    def inputs_labels_from_batch(self, batch_data):\n",
    "        return batch_data[\"img\"], batch_data[\"target\"]\n",
    "    \n",
    "custom_train_iter = CustomTrainIter(trainloader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
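    "`inputs_labels_from_batch()` receives each batch produced by the `DataLoader`. Here, the default collate function combines the dictionaries returned by `CIFAR10WithDict.__getitem__()` into a single dictionary of batched values. The method must return a tuple `input, label`, or in this case `img, target`."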
   ]
  },
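  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the hook mechanism concrete, here is a minimal pure-Python sketch of the wrapping pattern, modeled loosely on the `__next__` frames visible in the traceback above. `BatchIterSketch` and `DictBatchIter` are hypothetical stand-ins written for this illustration, not the classes from `torch_lr_finder`, and the batch holds plain lists rather than tensors.\n",
    "\n",
    "```python\n",
    "class BatchIterSketch:\n",
    "    # Hypothetical stand-in: wraps an iterable of batches and unpacks each one.\n",
    "\n",
    "    def __init__(self, loader):\n",
    "        self._iterator = iter(loader)\n",
    "\n",
    "    def inputs_labels_from_batch(self, batch):\n",
    "        # Default assumed here: the batch is an (inputs, labels, ...) sequence.\n",
    "        if not isinstance(batch, (list, tuple)):\n",
    "            raise ValueError('Unsupported batch type: {}'.format(type(batch)))\n",
    "        inputs, labels, *_ = batch\n",
    "        return inputs, labels\n",
    "\n",
    "    def __iter__(self):\n",
    "        return self\n",
    "\n",
    "    def __next__(self):\n",
    "        # Every batch passes through the hook before it is handed out.\n",
    "        return self.inputs_labels_from_batch(next(self._iterator))\n",
    "\n",
    "\n",
    "class DictBatchIter(BatchIterSketch):\n",
    "    def inputs_labels_from_batch(self, batch):\n",
    "        # Pull the two entries out of the dictionary batch.\n",
    "        return batch['img'], batch['target']\n",
    "\n",
    "\n",
    "# A list with one dictionary batch plays the role of the DataLoader.\n",
    "fake_loader = [{'img': [0.1, 0.2], 'target': [3, 5]}]\n",
    "inputs, labels = next(DictBatchIter(fake_loader))\n",
    "```\n",
    "\n",
    "In this sketch the base class rejects the dictionary batch with a `ValueError`, while the subclass only has to override the one hook to make the same data usable."
   ]
  },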
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3b266207edf64bf1a01331664717384e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Stopping early, the loss has diverged\n",
      "Learning rate search finished. See the graph with {finder_name}.plot()\n"
     ]
    }
   ],
   "source": [
    "lr_finder.reset()\n",
    "lr_finder.range_test(custom_train_iter, end_lr=100, num_iter=100, step_mode=\"exp\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Validation loss (Leslie N. Smith)\n",
    "\n",
    "We find something similar when using the validation loss test. If we try to run `range_test` with `trainloader` and `testloader` we get an error:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5e1f7e37a29648dbbfef1dababa98e01",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "ename": "ValueError",
     "evalue": "Your batch type not supported: <class 'dict'>. Please inherit from `TrainDataLoaderIter` (or `ValDataLoaderIter`) and redefine `_batch_make_inputs_labels` method.",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-10-89b71b5d7108>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0mlr_finder\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mlr_finder\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrange_test\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrainloader\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mval_loader\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtestloader\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mend_lr\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m100\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_iter\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m100\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstep_mode\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"exp\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m~/datascience/pytorch/pytorch-lr-finder/torch_lr_finder/lr_finder.py\u001b[0m in \u001b[0;36mrange_test\u001b[0;34m(self, train_loader, val_loader, start_lr, end_lr, num_iter, step_mode, smooth_f, diverge_th, accumulation_steps, non_blocking_transfer)\u001b[0m\n\u001b[1;32m    288\u001b[0m                 \u001b[0mtrain_iter\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    289\u001b[0m                 \u001b[0maccumulation_steps\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 290\u001b[0;31m                 \u001b[0mnon_blocking_transfer\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mnon_blocking_transfer\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    291\u001b[0m             )\n\u001b[1;32m    292\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mval_loader\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/datascience/pytorch/pytorch-lr-finder/torch_lr_finder/lr_finder.py\u001b[0m in \u001b[0;36m_train_batch\u001b[0;34m(self, train_iter, accumulation_steps, non_blocking_transfer)\u001b[0m\n\u001b[1;32m    339\u001b[0m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    340\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0maccumulation_steps\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 341\u001b[0;31m             \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_iter\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    342\u001b[0m             inputs, labels = self._move_to_device(\n\u001b[1;32m    343\u001b[0m                 \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_blocking\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mnon_blocking_transfer\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/datascience/pytorch/pytorch-lr-finder/torch_lr_finder/lr_finder.py\u001b[0m in \u001b[0;36m__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m     56\u001b[0m         \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     57\u001b[0m             \u001b[0mbatch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_iterator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 58\u001b[0;31m             \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minputs_labels_from_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     59\u001b[0m         \u001b[0;32mexcept\u001b[0m \u001b[0mStopIteration\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     60\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mauto_reset\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/datascience/pytorch/pytorch-lr-finder/torch_lr_finder/lr_finder.py\u001b[0m in \u001b[0;36minputs_labels_from_batch\u001b[0;34m(self, batch_data)\u001b[0m\n\u001b[1;32m     33\u001b[0m                 \u001b[0;34m\"Your batch type not supported: {}. Please inherit from \"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     34\u001b[0m                 \u001b[0;34m\"`TrainDataLoaderIter` (or `ValDataLoaderIter`) and redefine \"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 35\u001b[0;31m                 \u001b[0;34m\"`_batch_make_inputs_labels` method.\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch_data\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     36\u001b[0m             )\n\u001b[1;32m     37\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mValueError\u001b[0m: Your batch type not supported: <class 'dict'>. Please inherit from `TrainDataLoaderIter` (or `ValDataLoaderIter`) and redefine `_batch_make_inputs_labels` method."
     ]
    }
   ],
   "source": [
    "lr_finder.reset()\n",
    "lr_finder.range_test(trainloader, val_loader=testloader, end_lr=100, num_iter=100, step_mode=\"exp\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Even if we replace `trainloader` with `custom_train_iter`, we still get the error, because `testloader` still yields dictionary batches:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "25f4e7c42bc2481eb13cc0900015c808",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "ename": "ValueError",
     "evalue": "Your batch type not supported: <class 'dict'>. Please inherit from `TrainDataLoaderIter` (or `ValDataLoaderIter`) and redefine `_batch_make_inputs_labels` method.",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-11-7db1954b6657>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0mlr_finder\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mlr_finder\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrange_test\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcustom_train_iter\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mval_loader\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtestloader\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mend_lr\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m100\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_iter\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m100\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstep_mode\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"exp\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m~/datascience/pytorch/pytorch-lr-finder/torch_lr_finder/lr_finder.py\u001b[0m in \u001b[0;36mrange_test\u001b[0;34m(self, train_loader, val_loader, start_lr, end_lr, num_iter, step_mode, smooth_f, diverge_th, accumulation_steps, non_blocking_transfer)\u001b[0m\n\u001b[1;32m    292\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mval_loader\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    293\u001b[0m                 loss = self._validate(\n\u001b[0;32m--> 294\u001b[0;31m                     \u001b[0mval_iter\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_blocking_transfer\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mnon_blocking_transfer\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    295\u001b[0m                 )\n\u001b[1;32m    296\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/datascience/pytorch/pytorch-lr-finder/torch_lr_finder/lr_finder.py\u001b[0m in \u001b[0;36m_validate\u001b[0;34m(self, val_iter, non_blocking_transfer)\u001b[0m\n\u001b[1;32m    395\u001b[0m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0meval\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    396\u001b[0m         \u001b[0;32mwith\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mno_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 397\u001b[0;31m             \u001b[0;32mfor\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mval_iter\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    398\u001b[0m                 \u001b[0;31m# Move data to the correct device\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    399\u001b[0m                 inputs, labels = self._move_to_device(\n",
      "\u001b[0;32m~/datascience/pytorch/pytorch-lr-finder/torch_lr_finder/lr_finder.py\u001b[0m in \u001b[0;36m__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m     45\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m__next__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     46\u001b[0m         \u001b[0mbatch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_iterator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 47\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minputs_labels_from_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     48\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     49\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/datascience/pytorch/pytorch-lr-finder/torch_lr_finder/lr_finder.py\u001b[0m in \u001b[0;36minputs_labels_from_batch\u001b[0;34m(self, batch_data)\u001b[0m\n\u001b[1;32m     33\u001b[0m                 \u001b[0;34m\"Your batch type not supported: {}. Please inherit from \"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     34\u001b[0m                 \u001b[0;34m\"`TrainDataLoaderIter` (or `ValDataLoaderIter`) and redefine \"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 35\u001b[0;31m                 \u001b[0;34m\"`_batch_make_inputs_labels` method.\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch_data\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     36\u001b[0m             )\n\u001b[1;32m     37\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mValueError\u001b[0m: Your batch type not supported: <class 'dict'>. Please inherit from `TrainDataLoaderIter` (or `ValDataLoaderIter`) and redefine `_batch_make_inputs_labels` method."
     ]
    }
   ],
   "source": [
    "lr_finder.reset()\n",
    "lr_finder.range_test(custom_train_iter, val_loader=testloader, end_lr=100, num_iter=100, step_mode=\"exp\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The solution is to wrap `testloader` in a subclass of `ValDataLoaderIter` that overrides `inputs_labels_from_batch()`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CustomValIter(ValDataLoaderIter):\n",
    "    def inputs_labels_from_batch(self, batch_data):\n",
    "        return batch_data[\"img\"], batch_data[\"target\"]\n",
    "    \n",
    "custom_val_iter = CustomValIter(testloader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "bed059b7f2834752a52fa1f130cfbd59",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Learning rate search finished. See the graph with {finder_name}.plot()\n"
     ]
    }
   ],
   "source": [
    "lr_finder.reset()\n",
    "lr_finder.range_test(custom_train_iter, val_loader=custom_val_iter, end_lr=100, num_iter=100, step_mode=\"exp\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "lr_finder_p3",
   "language": "python",
   "name": "lr_finder_p3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
